#
# BSD 3-Clause License
#
# Copyright (c) 2017 xxxx
# All rights reserved.
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice, this
#   list of conditions and the following disclaimer.
#
# * Redistributions in binary form must reproduce the above copyright notice,
#   this list of conditions and the following disclaimer in the documentation
#   and/or other materials provided with the distribution.
#
# * Neither the name of the copyright holder nor the names of its
#   contributors may be used to endorse or promote products derived from
#   this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
# ============================================================================
#
from visdom import Visdom
import json 

class Visualizer(object):
    """Small convenience wrapper around a ``Visdom`` client.

    Windows are tracked by title in ``self.cur_win`` (title -> window id) so
    that repeated calls with the same name update the existing window instead
    of opening a new one.
    """

    def __init__(self, port='13579', env='main', id=None):
        """Connect to Visdom and restore the title -> window mapping.

        Args:
            port: Port forwarded to ``Visdom(port=...)``.
            env: Visdom environment name; also the default ``env`` used by
                :meth:`vis_image`.
            id: Optional tag. When not ``None``, :meth:`vis_scalar` and
                :meth:`vis_image` prefix window names with ``"[<id>]"``.
        """
        self.cur_win = {}
        self.vis = Visdom(port=port, env=env)
        self.id = id
        self.env = env
        # Restore windows that already exist so they are reused by title.
        # The reply is parsed as JSON and is expected to map each window id
        # to a description that carries a 'title' entry.
        ori_win = self.vis.get_window_data()
        ori_win = json.loads(ori_win)
        self.cur_win = {v['title']: k for k, v in ori_win.items()}

    def _tagged(self, name):
        """Return ``name`` prefixed with ``"[<id>]"`` when an id is set."""
        if self.id is not None:
            return "[%s]" % self.id + name
        return name

    @staticmethod
    def _merge_opts(name, opts):
        """Return ``{'title': name}`` overlaid with the caller's ``opts``."""
        merged = {'title': name}
        if opts is not None:
            merged.update(opts)
        return merged

    def vis_scalar(self, name, x, y, opts=None):
        """Plot a point (or points) on the line chart titled ``name``.

        Args:
            name: Window title (prefixed with the instance id, if any).
            x: X value(s); anything that is not a ``list`` is wrapped in one.
            y: Y value(s); anything that is not a ``list`` is wrapped in one.
            opts: Optional dict merged over the default ``{'title': name}``.
        """
        if not isinstance(x, list):
            x = [x]
        if not isinstance(y, list):
            y = [y]

        name = self._tagged(name)
        default_opts = self._merge_opts(name, opts)

        win = self.cur_win.get(name, None)
        if win is not None:
            # Known window: append the new point(s) to the existing line.
            self.vis.line(X=x, Y=y, opts=default_opts, update='append', win=win)
        else:
            # First use of this title: create the window and remember its id.
            self.cur_win[name] = self.vis.line(X=x, Y=y, opts=default_opts)

    def vis_image(self, name, img, env=None, opts=None):
        """Show ``img`` in the image window titled ``name``.

        Args:
            name: Window title (prefixed with the instance id, if any).
            img: Image passed through unchanged to ``Visdom.image``.
            env: Target environment; defaults to the one given at construction.
            opts: Optional dict merged over the default ``{'title': name}``.
        """
        if env is None:
            env = self.env
        name = self._tagged(name)
        default_opts = self._merge_opts(name, opts)

        win = self.cur_win.get(name, None)
        if win is not None:
            # Pass the merged options (not the raw ``opts``) so the title is
            # kept when an existing window is redrawn.
            self.vis.image(img=img, win=win, opts=default_opts, env=env)
        else:
            self.cur_win[name] = self.vis.image(img=img, opts=default_opts, env=env)

    def vis_table(self, name, tbl, opts=None):
        """Render the mapping ``tbl`` as a two-column HTML table.

        Args:
            name: Window title. NOTE(review): unlike ``vis_scalar`` and
                ``vis_image`` this name is not prefixed with the instance id
                -- confirm whether that is intentional.
            tbl: Mapping of term -> value; both are formatted with ``%s`` and
                inserted into the HTML as-is (no escaping).
            opts: Optional dict merged over the default ``{'title': name}``.
        """
        win = self.cur_win.get(name, None)

        # The backslash continuations below are inside the string literals,
        # so the leading whitespace of each continued line is part of the
        # emitted HTML.
        tbl_str = "<table width=\"100%\"> "
        tbl_str+="<tr> \
                 <th>Term</th> \
                 <th>Value</th> \
                 </tr>"
        for k, v in tbl.items():
            tbl_str+=  "<tr> \
                       <td>%s</td> \
                       <td>%s</td> \
                       </tr>"%(k, v)

        tbl_str+="</table>"

        default_opts = self._merge_opts(name, opts)
        if win is not None:
            self.vis.text(tbl_str, win=win, opts=default_opts)
        else:
            self.cur_win[name] = self.vis.text(tbl_str, opts=default_opts)


if __name__ == '__main__':
    import numpy as np

    # Manual smoke test: push a table to the Visdom instance on port 13500,
    # then push a second one under the same title so the window is updated
    # rather than duplicated.
    demo = Visualizer(port=13500, env='main')

    first_table = {"lr": 214, "momentum": 0.9}
    demo.vis_table("test_table", first_table)

    second_table = {"lr": 244444, "momentum": 0.9, "haha": "hoho"}
    demo.vis_table("test_table", second_table)
    